{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "275415f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip installs (%pip installs into the environment of the running kernel)\n",
    "\n",
    "%pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
    "%pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "535bd9de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from google.colab import userdata\n",
    "from huggingface_hub import login\n",
    "import torch\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, mean_absolute_percentage_error\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from datetime import datetime\n",
    "from peft import PeftModel\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc58234a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants\n",
    "\n",
    "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
    "PROJECT_NAME = \"pricer\"\n",
    "HF_USER = \"ed-donner\"\n",
    "RUN_NAME = \"2024-09-13_13.04.39\"\n",
    "PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
    "REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
    "FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
    "\n",
    "\n",
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "# Or just use the one I've uploaded\n",
    "# DATASET_NAME = \"ed-donner/pricer-data\"\n",
    "\n",
    "# Hyperparameters for QLoRA\n",
    "\n",
    "QUANT_4_BIT = True\n",
    "top_K = 6\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Used for writing to output in color\n",
    "\n",
    "GREEN = \"\\033[92m\"\n",
    "YELLOW = \"\\033[93m\"\n",
    "RED = \"\\033[91m\"\n",
    "RESET = \"\\033[0m\"\n",
    "COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0145ad8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Authenticate with the Hugging Face Hub; the token is read from the Colab secret named HF_TOKEN\n",
    "\n",
    "hf_token = userdata.get('HF_TOKEN')\n",
    "login(\n",
    "    token=hf_token,\n",
    "    add_to_git_credential=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6919506e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(DATASET_NAME)\n",
    "train_full = dataset['train']\n",
    "test_full = dataset['test']\n",
    "\n",
    "# Subset sizes for quick experiments; set either to None to use the whole split\n",
    "TRAIN_SIZE = 8000\n",
    "TEST_SIZE = 2000\n",
    "\n",
    "def first_n(split, size):\n",
    "    \"\"\"Return the first `size` rows of `split` (capped at its length), or the whole split if size is None.\"\"\"\n",
    "    if size is None:\n",
    "        return split\n",
    "    return split.select(range(min(size, len(split))))\n",
    "\n",
    "train = first_n(train_full, TRAIN_SIZE)\n",
    "test = first_n(test_full, TEST_SIZE)\n",
    "\n",
    "print(\"Dataset sizes:\")\n",
    "print(f\"  Train samples: {len(train)} (full dataset has {len(train_full)})\")\n",
    "print(f\"  Test samples: {len(test)} (full dataset has {len(test_full)})\")\n",
    "if len(train) < len(train_full) or len(test) < len(test_full):\n",
    "    print(\"\\nTo use the full dataset, set TRAIN_SIZE and TEST_SIZE to None\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea79cde1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# bitsandbytes quantization for the base model: 4-bit NF4 with double quantization\n",
    "# and bfloat16 compute when QUANT_4_BIT is set, otherwise plain 8-bit\n",
    "if not QUANT_4_BIT:\n",
    "    quant_config = BitsAndBytesConfig(load_in_8bit=True)\n",
    "else:\n",
    "    quant_config = BitsAndBytesConfig(\n",
    "        load_in_4bit=True,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef108f8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer, the quantized base model, and the fine-tuned PEFT adapter on top of it\n",
    "\n",
    "# NOTE(review): trust_remote_code=True lets the Hub repo run its own code; confirm it is needed here\n",
    "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "# Pad with the EOS token, on the right\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\"\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    BASE_MODEL,\n",
    "    device_map=\"auto\",\n",
    "    quantization_config=quant_config,\n",
    ")\n",
    "# Keep generate() consistent with the tokenizer's padding token\n",
    "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "# Pin the adapter to a specific Hub commit only when REVISION is set\n",
    "adapter_kwargs = {\"revision\": REVISION} if REVISION else {}\n",
    "fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, **adapter_kwargs)\n",
    "\n",
    "fine_tuned_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3c4176",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_price(s):\n",
    "    if \"Price is $\" in s:\n",
    "      contents = s.split(\"Price is $\")[1]\n",
    "      contents = contents.replace(',','')\n",
    "      match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
    "      return float(match.group()) if match else 0\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "436fa29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original prediction function: generate a few tokens and parse the price from the decoded text\n",
    "\n",
    "def model_predict(prompt, device=\"cuda\"):\n",
    "    \"\"\"Predict a price by letting the fine-tuned model generate up to 3 new tokens.\n",
    "\n",
    "    Args:\n",
    "        prompt: input text; presumably ends with \"Price is $\" so the completion is the price.\n",
    "        device: torch device the inputs are moved to (defaults to \"cuda\").\n",
    "\n",
    "    Returns:\n",
    "        The price parsed by extract_price from the decoded prompt + completion.\n",
    "    \"\"\"\n",
    "    # generate() follows the model's generation config, which may sample; seed for repeatability\n",
    "    set_seed(42)\n",
    "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
    "    # One unpadded sequence, so every position is attended to\n",
    "    attention_mask = torch.ones_like(inputs)\n",
    "    outputs = fine_tuned_model.generate(inputs, attention_mask=attention_mask, max_new_tokens=3, num_return_sequences=1)\n",
    "    response = tokenizer.decode(outputs[0])\n",
    "    return extract_price(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a666dab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def improved_model_predict(prompt, device=\"cuda\", top_k=None):\n",
    "    \"\"\"Predict a price as the probability-weighted average of the top-k next tokens.\n",
    "\n",
    "    Runs a single forward pass and takes the `top_k` most probable next tokens.\n",
    "    It keeps those that decode to a finite positive number and averages them,\n",
    "    weighted by their probabilities renormalized over the kept tokens.\n",
    "\n",
    "    Args:\n",
    "        prompt: text after which the model's next token should be the price.\n",
    "        device: torch device the inputs are moved to (defaults to \"cuda\").\n",
    "        top_k: number of candidate tokens; None means use the global top_K.\n",
    "\n",
    "    Returns:\n",
    "        The weighted price as a float, or 0.0 when no candidate is usable.\n",
    "    \"\"\"\n",
    "    if top_k is None:\n",
    "        top_k = top_K\n",
    "    set_seed(42)\n",
    "    inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
    "    attention_mask = torch.ones_like(inputs)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
    "        # Upcast so the softmax is not computed in a reduced-precision dtype\n",
    "        next_token_logits = outputs.logits[:, -1, :].float().cpu()\n",
    "\n",
    "    next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
    "    top_prob, top_token_id = next_token_probs.topk(top_k)\n",
    "\n",
    "    prices, weights = [], []\n",
    "    for token_id, probability in zip(top_token_id[0].tolist(), top_prob[0].tolist()):\n",
    "        try:\n",
    "            result = float(tokenizer.decode(token_id))\n",
    "        except ValueError:\n",
    "            continue  # the token is not a number\n",
    "        # float() also accepts tokens like \"nan\" and \"inf\"; keep only finite positive prices\n",
    "        if result > 0 and math.isfinite(result):\n",
    "            prices.append(result)\n",
    "            weights.append(probability)\n",
    "\n",
    "    total = sum(weights)\n",
    "    if total <= 0:\n",
    "        # No usable candidate: return a float (not a tuple) so callers can do arithmetic on it\n",
    "        return 0.0\n",
    "    return sum(price * weight for price, weight in zip(prices, weights)) / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9664c4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Tester:\n",
    "\n",
    "    def __init__(self, predictor, data, title=None, show_progress=True):\n",
    "        self.predictor = predictor\n",
    "        self.data = data\n",
    "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
    "        self.size = len(data)\n",
    "        self.guesses, self.truths, self.errors, self.rel_errors, self.sles, self.colors = [], [], [], [], [], []\n",
    "        self.show_progress = show_progress\n",
    "\n",
    "    def color_for(self, error, truth):\n",
    "        if error < 40 or error / truth < 0.2:\n",
    "            return \"green\"\n",
    "        elif error < 80 or error / truth < 0.4:\n",
    "            return \"orange\"\n",
    "        else:\n",
    "            return \"red\"\n",
    "\n",
    "    def run_datapoint(self, i):\n",
    "        datapoint = self.data[i]\n",
    "        guess = self.predictor(datapoint[\"text\"])\n",
    "        truth = datapoint[\"price\"]\n",
    "\n",
    "        error = guess - truth\n",
    "        abs_error = abs(error)\n",
    "        rel_error = abs_error / truth if truth != 0 else 0\n",
    "        log_error = math.log(truth + 1) - math.log(guess + 1)\n",
    "        sle = log_error ** 2\n",
    "        color = self.color_for(abs_error, truth)\n",
    "\n",
    "        title = (datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\") if \"\\n\\n\" in datapoint[\"text\"] else datapoint[\"text\"][:20]\n",
    "        self.guesses.append(guess)\n",
    "        self.truths.append(truth)\n",
    "        self.errors.append(error)\n",
    "        self.rel_errors.append(rel_error)\n",
    "        self.sles.append(sle)\n",
    "        self.colors.append(color)\n",
    "\n",
    "        print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} \"\n",
    "              f\"Error: ${abs_error:,.2f} RelErr: {rel_error*100:.1f}% SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
    "\n",
    "    def chart_all(self, chart_title):\n",
    "        \"\"\"Compact version: 4 performance charts in one grid.\"\"\"\n",
    "        t, g = np.array(self.truths), np.array(self.guesses)\n",
    "        rel_err, abs_err = np.array(self.rel_errors) * 100, np.abs(np.array(self.errors))\n",
    "\n",
    "        fig, axs = plt.subplots(2, 2, figsize=(14, 10))\n",
    "        fig.suptitle(f\"Performance Dashboard — {chart_title}\", fontsize=16, fontweight=\"bold\")\n",
    "\n",
    "        # Scatter plot\n",
    "        max_val = max(t.max(), g.max()) * 1.05\n",
    "        axs[1, 1].plot([0, max_val], [0, max_val], \"b--\", alpha=0.6)\n",
    "        axs[1, 1].scatter(t, g, s=20, c=self.colors, alpha=0.6)\n",
    "        axs[1, 1].set_title(\"Predictions vs Ground Truth\")\n",
    "        axs[1, 1].set_xlabel(\"True Price ($)\")\n",
    "        axs[1, 1].set_ylabel(\"Predicted ($)\")\n",
    "\n",
    "        # Accuracy by price range\n",
    "        bins = np.linspace(t.min(), t.max(), 6)\n",
    "        labels = [f\"${bins[i]:.0f}–${bins[i+1]:.0f}\" for i in range(len(bins)-1)]\n",
    "        inds = np.digitize(t, bins) - 1\n",
    "        avg_err = [rel_err[inds == i].mean() for i in range(len(labels))]\n",
    "        axs[0, 0].bar(labels, avg_err, color=\"seagreen\", alpha=0.8)\n",
    "        axs[0, 0].set_title(\"Avg Relative Error by Price Range\")\n",
    "        axs[0, 0].set_ylabel(\"Relative Error (%)\")\n",
    "        axs[0, 0].tick_params(axis=\"x\", rotation=30)\n",
    "\n",
    "        # Relative error distribution\n",
    "        axs[0, 1].hist(rel_err, bins=25, color=\"mediumpurple\", edgecolor=\"black\", alpha=0.7)\n",
    "        axs[0, 1].set_title(\"Relative Error Distribution (%)\")\n",
    "        axs[0, 1].set_xlabel(\"Relative Error (%)\")\n",
    "\n",
    "        # Absolute error distribution\n",
    "        axs[1, 0].hist(abs_err, bins=25, color=\"steelblue\", edgecolor=\"black\", alpha=0.7)\n",
    "        axs[1, 0].axvline(abs_err.mean(), color=\"red\", linestyle=\"--\", label=f\"Mean={abs_err.mean():.2f}\")\n",
    "        axs[1, 0].set_title(\"Absolute Error Distribution\")\n",
    "        axs[1, 0].set_xlabel(\"Absolute Error ($)\")\n",
    "        axs[1, 0].legend()\n",
    "\n",
    "        for ax in axs.ravel():\n",
    "            ax.grid(alpha=0.3)\n",
    "\n",
    "        plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "        plt.show()\n",
    "\n",
    "    def report(self):\n",
    "        y_true = np.array(self.truths)\n",
    "        y_pred = np.array(self.guesses)\n",
    "\n",
    "        mae = mean_absolute_error(y_true, y_pred)\n",
    "        rmse = math.sqrt(mean_squared_error(y_true, y_pred))\n",
    "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
    "        mape = mean_absolute_percentage_error(y_true, y_pred) * 100\n",
    "        median_error = float(np.median(np.abs(y_true - y_pred)))\n",
    "        r2 = r2_score(y_true, y_pred)\n",
    "\n",
    "        hit_rate_green = sum(1 for c in self.colors if c == \"green\") / self.size * 100\n",
    "        hit_rate_acceptable = sum(1 for c in self.colors if c in (\"green\", \"orange\")) / self.size * 100\n",
    "\n",
    "        print(f\"\\n{'='*70}\")\n",
    "        print(f\"FINAL REPORT: {self.title}\")\n",
    "        print(f\"{'='*70}\")\n",
    "        print(f\"Total Predictions: {self.size}\")\n",
    "        print(f\"\\n--- Error Metrics ---\")\n",
    "        print(f\"Mean Absolute Error (MAE): ${mae:,.2f}\")\n",
    "        print(f\"Median Error: ${median_error:,.2f}\")\n",
    "        print(f\"Root Mean Squared Error (RMSE): ${rmse:,.2f}\")\n",
    "        print(f\"Root Mean Squared Log Error (RMSLE): {rmsle:.4f}\")\n",
    "        print(f\"Mean Absolute Percentage Error (MAPE): {mape:.2f}%\")\n",
    "        print(f\"\\n--- Accuracy Metrics ---\")\n",
    "        print(f\"R² Score: {r2:.4f}\")\n",
    "        print(f\"Hit Rate (Green): {hit_rate_green:.1f}%\")\n",
    "        print(f\"Hit Rate (Green+Orange): {hit_rate_acceptable:.1f}%\")\n",
    "        print(f\"{'='*70}\\n\")\n",
    "        chart_title = f\"{self.title} | MAE=${mae:,.2f} | RMSLE={rmsle:.3f} | R²={r2:.3f}\"\n",
    "\n",
    "        self.chart_all(chart_title)\n",
    "\n",
    "    def run(self):\n",
    "        iterator = tqdm(range(self.size), desc=\"Testing Model\") if self.show_progress else range(self.size)\n",
    "        for i in iterator:\n",
    "            self.run_datapoint(i)\n",
    "        self.report()\n",
    "\n",
    "    @classmethod\n",
    "    def test(cls, function, data, title=None):\n",
    "        cls(function, data, title=title).run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e60a696",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the probability-weighted predictor on the test split\n",
    "evaluation_title = \"ed-donner Fine-tuned [Base | Llama 3.1 8B] (Improved - Small Test Set)\"\n",
    "Tester(improved_model_predict, test, title=evaluation_title).run()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
